{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision tree for classification in plain Python\n",
    "\n",
    "\n",
    "A decision tree is a **supervised** machine learning model that can be used both for **classification** and **regression**. At its core, a decision tree uses a tree structure to predict an output value for a given input example. In the tree, each path from the root node to a leaf node represents a *decision path* that ends in a predicted value. \n",
    "\n",
    "A simple example might look as follows:\n",
    "![caption](figures/decision_tree.png)\n",
    "\n",
    "Decision trees have many advantages. For example, they are easy to understand and their decisions are easy to interpret. Also, they don't require a lot of data preparation. A more extensive list of their advantages and disadvantages can be found [here](http://scikit-learn.org/stable/modules/tree.html).\n",
    "\n",
    "### CART training algorithm \n",
    "In order to train a decision tree, various algorithms can be used. In this notebook we will focus on the *CART* algorithm (Classification and Regression Trees) for *classification*. The CART algorithm builds a *binary tree* in which every non-leaf node has exactly two children (corresponding to a yes/no answer). \n",
    "\n",
    "Given a set of training examples and their labels, the algorithm repeatedly splits the training examples $D$ into two subsets $D_{left}, D_{right}$ using some feature $f$ and feature threshold $t_f$ such that samples with the same label are grouped together. At each node, the algorithm selects the split $\\theta = (f, t_f)$ that produces the purest subsets, weighted by their size. Purity/impurity is measured using the *Gini impurity*.\n",
    "\n",
    "So at each step, the algorithm selects the parameters $\\theta$ that minimize the following cost function:\n",
    "\n",
    "\\begin{equation}\n",
    "J(D, \\theta) = \\frac{n_{left}}{n_{total}} G_{left} + \\frac{n_{right}}{n_{total}} G_{right}\n",
    "\\end{equation}\n",
    "\n",
    "- $D$: remaining training examples   \n",
    "- $n_{total}$ : number of remaining training examples\n",
    "- $\\theta = (f, t_f)$: feature and feature threshold\n",
    "- $n_{left}/n_{right}$: number of samples in the left/right subset\n",
    "- $G_{left}/G_{right}$: Gini impurity of the left/right subset\n",
    "\n",
    "This step is repeated recursively until the *maximum allowable depth* is reached or the current number of samples $n_{total}$ drops below some minimum number. The original equations can be found [here](http://scikit-learn.org/stable/modules/tree.html).\n",
    "\n",
    "After building the tree, new examples can be classified by navigating through the tree, testing at each node the corresponding feature until a leaf node/prediction is reached.\n",
    "\n",
    "\n",
    "### Gini Impurity\n",
    "\n",
    "Given $K$ different classification values $k \\in \\{1, ..., K\\}$ the Gini impurity of node $m$ is computed as follows:\n",
    "\n",
    "\\begin{equation}\n",
    "G_m = 1 - \\sum_{k=1}^{K} (p_{m,k})^2\n",
    "\\end{equation}\n",
    "\n",
    "where $p_{m, k}$ is the fraction of training examples with class $k$ among all training examples in node $m$.\n",
    "\n",
    "The Gini impurity can be used to evaluate how good a potential split is. A split divides a given set of training examples into two groups. Gini measures how \"mixed\" the resulting groups are. A perfect separation (i.e. each group contains only samples of the same class) corresponds to a Gini impurity of 0. If the resulting groups contain equally many samples of each class, the Gini impurity will reach its highest value of 0.5\n",
    "\n",
    "### Caveats\n",
    "\n",
    "Without regularization, decision trees are likely to overfit the training examples. This can be prevented using techniques like *pruning* or by providing a maximum allowed tree depth and/or a minimum number of samples required to split a node further."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:18.442511Z",
     "start_time": "2018-04-09T12:50:18.426963Z"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "np.random.seed(123)\n",
    "\n",
    "% matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "The iris dataset compromises 150 examples of 3 different species of iris flowers (Setosa, Versicolour, and Virginica). Each example is described by four attributes: sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:18.959842Z",
     "start_time": "2018-04-09T12:50:18.951915Z"
    }
   },
   "outputs": [],
   "source": [
    "iris = load_iris()\n",
    "\n",
    "X, y = iris.data, iris.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:19.202874Z",
     "start_time": "2018-04-09T12:50:19.188064Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape X_train: (112, 4)\n",
      "Shape y_train: (112,)\n",
      "Shape X_test: (38, 4)\n",
      "Shape y_test: (38,)\n"
     ]
    }
   ],
   "source": [
    "# Split the data into a training and test set\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
    "\n",
    "print(f'Shape X_train: {X_train.shape}')\n",
    "print(f'Shape y_train: {y_train.shape}')\n",
    "print(f'Shape X_test: {X_test.shape}')\n",
    "print(f'Shape y_test: {y_test.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decision tree class\n",
    "\n",
    "Parts of this code were inspired by [this tutorial](https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:20.706089Z",
     "start_time": "2018-04-09T12:50:20.093228Z"
    }
   },
   "outputs": [],
   "source": [
    "class DecisionTree:\n",
    "    \"\"\"\n",
    "    Decision tree for classification\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.root_dict = None\n",
    "        self.tree_dict = None\n",
    "\n",
    "    def split_dataset(self, X, y, feature_idx, threshold):\n",
    "        \"\"\"\n",
    "        Splits dataset X into two subsets, according to a given feature\n",
    "        and feature threshold.\n",
    "\n",
    "        Args:\n",
    "            X: 2D numpy array with data samples\n",
    "            y: 1D numpy array with labels\n",
    "            feature_idx: int, index of feature used for splitting the data\n",
    "            threshold: float, threshold used for splitting the data\n",
    "\n",
    "        Returns:\n",
    "            splits: dict containing the left and right subsets\n",
    "            and their labels\n",
    "        \"\"\"\n",
    "\n",
    "        left_idx = np.where(X[:, feature_idx] < threshold)\n",
    "        right_idx = np.where(X[:, feature_idx] >= threshold)\n",
    "\n",
    "        left_subset = X[left_idx]\n",
    "        y_left = y[left_idx]\n",
    "\n",
    "        right_subset = X[right_idx]\n",
    "        y_right = y[right_idx]\n",
    "\n",
    "        splits = {\n",
    "        'left': left_subset,\n",
    "        'y_left': y_left,\n",
    "        'right': right_subset,\n",
    "        'y_right': y_right,\n",
    "        }\n",
    "\n",
    "        return splits\n",
    "\n",
    "    def gini_impurity(self, y_left, y_right, n_left, n_right):\n",
    "        \"\"\"\n",
    "        Computes Gini impurity of a split.\n",
    "\n",
    "        Args:\n",
    "            y_left, y_right: target values of samples in left/right subset\n",
    "            n_left, n_right: number of samples in left/right subset\n",
    "\n",
    "        Returns:\n",
    "            gini_left: float, Gini impurity of left subset\n",
    "            gini_right: gloat, Gini impurity of right subset\n",
    "        \"\"\"\n",
    "\n",
    "        n_total = n_left + n_left\n",
    "\n",
    "        score_left, score_right = 0, 0\n",
    "        gini_left, gini_right = 0, 0\n",
    "\n",
    "        if n_left != 0:\n",
    "            for c in range(self.n_classes):\n",
    "                # For each class c, compute fraction of samples with class c\n",
    "                p_left = len(np.where(y_left == c)[0]) / n_left\n",
    "                score_left += p_left * p_left\n",
    "            gini_left = 1 - score_left\n",
    "\n",
    "        if n_right != 0:\n",
    "            for c in range(self.n_classes):\n",
    "                p_right = len(np.where(y_right == c)[0]) / n_right\n",
    "                score_right += p_right * p_right\n",
    "            gini_right = 1 - score_right\n",
    "\n",
    "        return gini_left, gini_right\n",
    "\n",
    "    def get_cost(self, splits):\n",
    "        \"\"\"\n",
    "        Computes cost of a split given the Gini impurity of\n",
    "        the left and right subset and the sizes of the subsets\n",
    "        \n",
    "        Args:\n",
    "            splits: dict, containing params of current split\n",
    "        \"\"\"\n",
    "        y_left = splits['y_left']\n",
    "        y_right = splits['y_right']\n",
    "\n",
    "        n_left = len(y_left)\n",
    "        n_right = len(y_right)\n",
    "        n_total = n_left + n_right\n",
    "\n",
    "        gini_left, gini_right = self.gini_impurity(y_left, y_right, n_left, n_right)\n",
    "        cost = (n_left / n_total) * gini_left + (n_right / n_total) * gini_right\n",
    "\n",
    "        return cost\n",
    "\n",
    "    def find_best_split(self, X, y):\n",
    "        \"\"\"\n",
    "        Finds the best feature and feature index to split dataset X into\n",
    "        two groups. Checks every value of every attribute as a candidate\n",
    "        split.\n",
    "\n",
    "        Args:\n",
    "            X: 2D numpy array with data samples\n",
    "            y: 1D numpy array with labels\n",
    "\n",
    "        Returns:\n",
    "            best_split_params: dict, containing parameters of the best split\n",
    "        \"\"\"\n",
    "\n",
    "        n_samples, n_features = X.shape\n",
    "\n",
    "        best_feature_idx, best_threshold, best_cost, best_splits = np.inf, np.inf, np.inf, None\n",
    "\n",
    "        for feature_idx in range(n_features):\n",
    "            for i in range(n_samples):\n",
    "                current_sample = X[i]\n",
    "                threshold = current_sample[feature_idx]\n",
    "                splits = self.split_dataset(X, y, feature_idx, threshold)\n",
    "                cost = self.get_cost(splits)\n",
    "\n",
    "                if cost < best_cost:\n",
    "                    best_feature_idx = feature_idx\n",
    "                    best_threshold = threshold\n",
    "                    best_cost = cost\n",
    "                    best_splits = splits\n",
    "\n",
    "        best_split_params = {\n",
    "            'feature_idx': best_feature_idx,\n",
    "            'threshold': best_threshold,\n",
    "            'cost': best_cost,\n",
    "            'left': best_splits['left'],\n",
    "            'y_left': best_splits['y_left'],\n",
    "            'right': best_splits['right'],\n",
    "            'y_right': best_splits['y_right'],\n",
    "        }\n",
    "\n",
    "        return best_split_params\n",
    "\n",
    "\n",
    "    def build_tree(self, node_dict, depth, max_depth, min_samples):\n",
    "        \"\"\"\n",
    "        Builds the decision tree in a recursive fashion.\n",
    "\n",
    "        Args:\n",
    "            node_dict: dict, representing the current node\n",
    "            depth: int, depth of current node in the tree\n",
    "            max_depth: int, maximum allowed tree depth\n",
    "            min_samples: int, minimum number of samples needed to split a node further\n",
    "\n",
    "        Returns:\n",
    "            node_dict: dict, representing the full subtree originating from current node\n",
    "        \"\"\"\n",
    "        left_samples = node_dict['left']\n",
    "        right_samples = node_dict['right']\n",
    "        y_left_samples = node_dict['y_left']\n",
    "        y_right_samples = node_dict['y_right']\n",
    "\n",
    "        if len(y_left_samples) == 0 or len(y_right_samples) == 0:\n",
    "            node_dict[\"left_child\"] = node_dict[\"right_child\"] = self.create_terminal_node(np.append(y_left_samples, y_right_samples))\n",
    "            return None\n",
    "\n",
    "        if depth >= max_depth:\n",
    "            node_dict[\"left_child\"] = self.create_terminal_node(y_left_samples)\n",
    "            node_dict[\"right_child\"] = self.create_terminal_node(y_right_samples)\n",
    "            return None\n",
    "\n",
    "        if len(right_samples) < min_samples:\n",
    "            node_dict[\"right_child\"] = self.create_terminal_node(y_right_samples)\n",
    "        else:\n",
    "            node_dict[\"right_child\"] = self.find_best_split(right_samples, y_right_samples)\n",
    "            self.build_tree(node_dict[\"right_child\"], depth+1, max_depth, min_samples)\n",
    "\n",
    "        if len(left_samples) < min_samples:\n",
    "            node_dict[\"left_child\"] = self.create_terminal_node(y_left_samples)\n",
    "        else:\n",
    "            node_dict[\"left_child\"] = self.find_best_split(left_samples, y_left_samples)\n",
    "            self.build_tree(node_dict[\"left_child\"], depth+1, max_depth, min_samples)\n",
    "\n",
    "        return node_dict\n",
    "\n",
    "    def create_terminal_node(self, y):\n",
    "        \"\"\"\n",
    "        Creates a terminal node.\n",
    "        Given a set of labels the most common label is computed and\n",
    "        set as the classification value of the node.\n",
    "\n",
    "        Args:\n",
    "            y: 1D numpy array with labels\n",
    "        Returns:\n",
    "            classification: int, predicted class\n",
    "        \"\"\"\n",
    "        classification = max(set(y), key=list(y).count)\n",
    "        return classification\n",
    "\n",
    "    def train(self, X, y, max_depth, min_samples):\n",
    "        \"\"\"\n",
    "        Fits decision tree on a given dataset.\n",
    "\n",
    "        Args:\n",
    "            X: 2D numpy array with data samples\n",
    "            y: 1D numpy array with labels\n",
    "            max_depth: int, maximum allowed tree depth\n",
    "            min_samples: int, minimum number of samples needed to split a node further\n",
    "        \"\"\"\n",
    "        self.n_classes = len(set(y))\n",
    "        self.root_dict = self.find_best_split(X, y)\n",
    "        self.tree_dict = self.build_tree(self.root_dict, 1, max_depth, min_samples)\n",
    "\n",
    "    def predict(self, X, node):\n",
    "        \"\"\"\n",
    "        Predicts the class for a given input example X.\n",
    "\n",
    "        Args:\n",
    "            X: 1D numpy array, input example\n",
    "            node: dict, representing trained decision tree\n",
    "\n",
    "        Returns:\n",
    "            prediction: int, predicted class\n",
    "        \"\"\"\n",
    "        feature_idx = node['feature_idx']\n",
    "        threshold = node['threshold']\n",
    "\n",
    "        if X[feature_idx] < threshold:\n",
    "            if isinstance(node['left_child'], (int, np.integer)):\n",
    "                return node['left_child']\n",
    "            else:\n",
    "                prediction = self.predict(X, node['left_child'])\n",
    "        elif X[feature_idx] >= threshold:\n",
    "            if isinstance(node['right_child'], (int, np.integer)):\n",
    "                return node['right_child']\n",
    "            else:\n",
    "                prediction = self.predict(X, node['right_child'])\n",
    "\n",
    "        return prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:19:30.968801Z",
     "start_time": "2018-04-09T12:19:30.963691Z"
    }
   },
   "source": [
    "## Initializing and training the decision tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:21.291025Z",
     "start_time": "2018-04-09T12:50:21.235025Z"
    }
   },
   "outputs": [],
   "source": [
    "tree = DecisionTree()\n",
    "tree.train(X_train, y_train, max_depth=2, min_samples=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Printing the decision tree structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:24.127183Z",
     "start_time": "2018-04-09T12:50:24.115505Z"
    }
   },
   "outputs": [],
   "source": [
    "def print_tree(node, depth=0):\n",
    "    if isinstance(node, (int, np.integer)):\n",
    "        print(f\"{depth * '  '}predicted class: {iris.target_names[node]}\")\n",
    "    else:\n",
    "        print(f\"{depth * '  '}{iris.feature_names[node['feature_idx']]} < {node['threshold']}, \"\n",
    "             f\"cost of split: {round(node['cost'], 3)}\")\n",
    "        print_tree(node[\"left_child\"], depth+1)\n",
    "        print_tree(node[\"right_child\"], depth+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:24.818384Z",
     "start_time": "2018-04-09T12:50:24.809082Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "petal length (cm) < 3.0, cost of split: 0.346\n",
      "  sepal length (cm) < 5.4, cost of split: 0.0\n",
      "    predicted class: setosa\n",
      "    predicted class: setosa\n",
      "  petal width (cm) < 1.8, cost of split: 0.097\n",
      "    predicted class: versicolor\n",
      "    predicted class: virginica\n"
     ]
    }
   ],
   "source": [
    "print_tree(tree.tree_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing the decision tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-04-09T12:50:28.307583Z",
     "start_time": "2018-04-09T12:50:28.284228Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on test set: 0.9473684210526315\n"
     ]
    }
   ],
   "source": [
    "all_predictions = []\n",
    "for i in range(X_test.shape[0]):\n",
    "    result = tree.predict(X_test[i], tree.tree_dict)\n",
    "    all_predictions.append(y_test[i] == result)\n",
    "\n",
    "print(f\"Accuracy on test set: {sum(all_predictions) / len(all_predictions)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
